{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Environment Setup\n",
    "\n",
    "To start working with Flower, very little is required once you have activated your Python environment (e.g. via `conda`, `virtualenv`, `pyenv`, etc). If you are running this code on Colab, there is really nothing to do except to install Flower and other dependencies. The steps below have been verified to run in Colab.\n",
    "\n",
    "## Installing Flower\n",
    "\n",
    "You can install Flower conveniently with `pip`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# depending on your shell, you might need to add `\\` before `[` and `]`.\n",
    "!pip install -q flwr[simulation]\n",
    "!pip install flwr_datasets[vision]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.ai/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs or even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Flower is agnostic to your choice of ML framework. Flower works with `PyTorch`, `TensorFlow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart-_ example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory.\n",
    "\n",
    "In this tutorial we are going to use PyTorch. It comes pre-installed in your Colab runtime, so there is no need to install it again. If you would like to install another version, you can still do so in the same way other packages are installed, via `!pip`."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to install some other dependencies you are likely familiar with. Let's install `matplotlib` to plot our results at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "58b7af77-609f-4118-bd5b-5629a4b5a296"
   },
   "outputs": [],
   "source": [
    "!pip install matplotlib"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparing the experiment\n",
    "\n",
    "This tutorial is not so much about novel architectural designs so we keep things simple and make use of a typical CNN that is adequate for the MNIST image classification task.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, num_classes: int) -> None:\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, num_classes)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 4 * 4)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
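  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the `16 * 4 * 4` input size of `fc1`, the sketch below traces the spatial size of a 28x28 MNIST image through the two convolution and pooling stages of `Net`. It is an illustration only: the helper `out_size()` is a hypothetical name introduced here, and it assumes unpadded, stride-1 convolutions (the `nn.Conv2d` defaults) and 2x2 pooling with stride 2, as in the model above.\n",
    "\n",
    "```python\n",
    "def out_size(size, kernel, stride=1, padding=0):\n",
    "    # Standard output-size formula for a convolution or pooling window\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "side = 28                                   # MNIST images are 28x28\n",
    "side = out_size(side, kernel=5)             # conv1: 5x5 kernel -> 24\n",
    "side = out_size(side, kernel=2, stride=2)   # 2x2 max-pool -> 12\n",
    "side = out_size(side, kernel=5)             # conv2: 5x5 kernel -> 8\n",
    "side = out_size(side, kernel=2, stride=2)   # 2x2 max-pool -> 4\n",
    "\n",
    "flat_features = 16 * side * side            # conv2 has 16 output channels\n",
    "print(side, flat_features)                  # prints: 4 256\n",
    "```\n",
    "\n",
    "Changing the input resolution or the kernel sizes changes this number, so `fc1` would need to be resized accordingly."
   ]
  },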
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be training the model in a Federated setting. In order to do that, we need to define two functions:\n",
    "\n",
    "* `train()` that will train the model given a dataloader.\n",
    "* `test()` that will be used to evaluate the performance of the model on held-out data, e.g., a test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(net, trainloader, optim, epochs, device: str):\n",
    "    \"\"\"Train the network on the training set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    net.train()\n",
    "    for _ in range(epochs):\n",
    "        for batch in trainloader:\n",
    "            images, labels = batch[\"image\"].to(device), batch[\"label\"].to(device)\n",
    "            optim.zero_grad()\n",
    "            loss = criterion(net(images), labels)\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "\n",
    "\n",
    "def test(net, testloader, device: str):\n",
    "    \"\"\"Validate the network on the entire test set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    correct, loss = 0, 0.0\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        for data in testloader:\n",
    "            images, labels = data[\"image\"].to(device), data[\"label\"].to(device)\n",
    "            outputs = net(images)\n",
    "            loss += criterion(outputs, labels).item()\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    accuracy = correct / len(testloader.dataset)\n",
    "    return loss, accuracy"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code we have written so far is not specific to Federated Learning. What, then, are the key differences between Federated Learning and Centralised Training? If you could only pick two, you would probably say:\n",
    "* Federated Learning is distributed -- the model is trained on-device by the participating clients.\n",
    "* Data remains private and is owned by a specific _client_ -- the data is never sent to the central server.\n",
    "\n",
    "There are several more differences, but the two above are the main ones to always consider, and they are common to all flavours of Federated Learning (e.g. _cross-device_ or _cross-silo_). The remainder of this tutorial focuses on transforming the code we have written so far for the centralised setting into a Federated Learning pipeline using Flower and PyTorch.\n",
    "\n",
    "Let's begin! 🚀"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One Client, One Data Partition\n",
    "\n",
    "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.\n",
    "\n",
    "We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "from flwr_datasets import FederatedDataset\n",
    "from datasets.utils.logging import disable_progress_bar\n",
    "\n",
    "# Let's set a simulation involving a total of 100 clients\n",
    "NUM_CLIENTS = 100\n",
    "\n",
    "# Download the MNIST dataset and partition its \"train\" split (so one partition can be assigned to each client)\n",
    "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n",
    "# Let's keep the test set as is, and use it to evaluate the global model on the server\n",
    "centralized_testset = mnist_fds.load_split(\"test\")"
   ]
  },
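  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of one partition per client concrete, here is a minimal sketch of IID partitioning in plain Python. It is not how Flower Datasets implements partitioning; the function name `iid_partition()` and the seed are assumptions made for illustration. It shuffles the sample indices once and deals them out so that each client receives a disjoint, equally sized shard of the 60,000 MNIST training images.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def iid_partition(num_samples, num_clients, seed=42):\n",
    "    # Shuffle all sample indices once, then deal them out round-robin\n",
    "    indices = list(range(num_samples))\n",
    "    random.Random(seed).shuffle(indices)\n",
    "    return [indices[i::num_clients] for i in range(num_clients)]\n",
    "\n",
    "\n",
    "partitions = iid_partition(num_samples=60_000, num_clients=100)\n",
    "print(len(partitions), len(partitions[0]))  # prints: 100 600\n",
    "```\n",
    "\n",
    "With 100 clients, each one ends up with 600 training images before any local train/validation split."
   ]
  },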
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a function that applies a set of transforms to our images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import ToTensor, Normalize, Compose\n",
    "\n",
    "\n",
    "def apply_transforms(batch):\n",
    "    \"\"\"Apply the MNIST transforms to a batch of images.\"\"\"\n",
    "\n",
    "    # transformation to convert images to tensors and apply normalization\n",
    "    transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n",
    "    batch[\"image\"] = [transforms(img) for img in batch[\"image\"]]\n",
    "    return batch"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's next define how our FL clients will behave.\n",
    "\n",
    "## Defining a Flower Client\n",
    "\n",
    "You can think of a client in FL as an entity that owns some data and trains a model using this data. The caveat is that the model is being trained _collaboratively_ in Federation by multiple clients (sometimes up to hundreds of thousands) and, in most instances of FL, is sent by a central server.\n",
    "\n",
    "A Flower Client is a simple Python class with four distinct methods:\n",
    "\n",
    "* `fit()`: With this method, the client does on-device training for a number of epochs using its own data. At the end, the resulting model is sent back to the server for aggregation.\n",
    "\n",
    "* `evaluate()`: With this method, the server can evaluate the performance of the global model on the local validation set of a client. This can be used, for instance, when there is no centralised dataset on the server for validation/test. This method can also be used to assess the degree of personalisation of the model being federated.\n",
    "\n",
    "* `set_parameters()`: This method takes the parameters sent by the server and uses them to initialise the parameters of the local model, which is ML framework specific (e.g. TF, PyTorch, etc).\n",
    "\n",
    "* `get_parameters()`: It extracts the parameters from the local model and transforms them into a list of NumPy arrays. This ML framework-agnostic representation of the model will be sent to the server.\n",
    "\n",
    "Let's start by importing Flower!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import flwr as fl"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define our Flower Client class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "from flwr.common import NDArrays, Scalar\n",
    "\n",
    "\n",
    "class FlowerClient(fl.client.NumPyClient):\n",
    "    def __init__(self, trainloader, valloader) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "        self.model = Net(num_classes=10)\n",
    "        # Determine device\n",
    "        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.model.to(self.device)  # send model to device\n",
    "\n",
    "    def set_parameters(self, parameters):\n",
    "        \"\"\"With the model parameters received from the server,\n",
    "        overwrite the uninitialised model in this class with them.\"\"\"\n",
    "\n",
    "        params_dict = zip(self.model.state_dict().keys(), parameters)\n",
    "        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n",
    "        # now replace the parameters\n",
    "        self.model.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "    def get_parameters(self, config: Dict[str, Scalar]):\n",
    "        \"\"\"Extract all model parameters and convert them to a list of\n",
    "        NumPy arrays. The server doesn't work with PyTorch/TF/etc.\"\"\"\n",
    "        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]\n",
    "\n",
    "    def fit(self, parameters, config):\n",
    "        \"\"\"This method trains the model using the parameters sent by the\n",
    "        server on the dataset of this client. At the end, the parameters\n",
    "        of the locally trained model are communicated back to the server.\"\"\"\n",
    "\n",
    "        # copy parameters sent by the server into client's local model\n",
    "        self.set_parameters(parameters)\n",
    "\n",
    "        # read from config\n",
    "        lr, epochs = config[\"lr\"], config[\"epochs\"]\n",
    "\n",
    "        # Define the optimizer\n",
    "        optim = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)\n",
    "\n",
    "        # do local training\n",
    "        train(self.model, self.trainloader, optim, epochs=epochs, device=self.device)\n",
    "\n",
    "        # return the model parameters to the server as well as extra info (number of training examples in this case)\n",
    "        return self.get_parameters({}), len(self.trainloader.dataset), {}\n",
    "\n",
    "    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):\n",
    "        \"\"\"Evaluate the model sent by the server on this client's\n",
    "        local validation set. Then return performance metrics.\"\"\"\n",
    "\n",
    "        self.set_parameters(parameters)\n",
    "        loss, accuracy = test(self.model, self.valloader, device=self.device)\n",
    "        # send statistics back to the server\n",
    "        return float(loss), len(self.valloader.dataset), {\"accuracy\": accuracy}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Spend a few minutes inspecting the `FlowerClient` class above. Please ask questions if anything is unclear!\n",
    "\n",
    "The keen-eyed among you might have realised that if we were to fuse the client's `fit()` and `evaluate()` methods, we would end up with essentially the same code as the `run_centralised()` function we used in the Centralised Training part of this tutorial. And it is true! In Federated Learning, the way clients perform local training follows the same principles as a more traditional centralised setup. The key difference is that the dataset is now much smaller and is never _\"seen\"_ by the entity running the FL workload (i.e. the central server).\n",
    "\n",
    "\n",
    "Talking about the central server... we should define which strategy to use so that the updated models sent from the clients back to the server at the end of the `fit()` method are aggregated.\n",
    "\n",
    "\n",
    "## Choosing a Flower Strategy\n",
    "\n",
    "\n",
    "A strategy sits at the core of the Federated Learning experiment. It is involved in all stages of an FL pipeline: sampling clients; sending the _global model_ to the clients so they can do `fit()`; receiving the updated models from the clients and **aggregating** them to construct a new _global model_; defining and executing global or federated evaluation; and more.\n",
    "\n",
    "Flower comes with [many strategies built-in](https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy) and more to be available in the next release (`1.5` already!). For this tutorial, let's use what is arguably the most popular strategy out there: `FedAvg`.\n",
    "\n",
    "The way `FedAvg` works is simple, yet it performs surprisingly well in practice. It is therefore a good strategy to start your experimentation with. `FedAvg`, as its name implies, derives a new version of the _global model_ by taking the average (weighted by the number of training examples) of all the models sent by clients participating in the round. You can read all the details [in the paper](https://arxiv.org/abs/1602.05629).\n",
    "\n",
    "Let's see how we can define `FedAvg` using Flower. We use one of the callbacks, called `evaluate_fn`, so we can easily evaluate the state of the global model using a small centralised test set. Note that this functionality is user-defined since it requires a choice of ML framework (if you recall, Flower is framework agnostic).\n",
    "\n",
    "> This being said, centralised evaluation of the global model is only possible if there exists a centralised dataset that roughly follows the same distribution as the data spread across clients. In some cases having such a centralised dataset for validation is not possible, so the only solution is to federate the evaluation of the _global model_. This is the default behaviour in Flower. If you don't specify the `evaluate_fn` argument in your strategy, then centralised global evaluation won't be performed."
   ]
  },
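  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The averaging step at the heart of `FedAvg` can be sketched in a few lines of plain Python. This is a simplified illustration, not Flower's implementation: `fedavg()` is a hypothetical helper, and each model is represented as a flat list of floats instead of a list of NumPy arrays. Each client's parameters are weighted by the number of examples it trained on.\n",
    "\n",
    "```python\n",
    "def fedavg(client_updates):\n",
    "    # client_updates: list of (num_examples, weights) pairs, where weights\n",
    "    # is a flat list of floats standing in for the model parameters\n",
    "    total = sum(n for n, _ in client_updates)\n",
    "    num_params = len(client_updates[0][1])\n",
    "    return [\n",
    "        sum(n * w[i] for n, w in client_updates) / total\n",
    "        for i in range(num_params)\n",
    "    ]\n",
    "\n",
    "\n",
    "updates = [\n",
    "    (100, [1.0, 2.0]),  # client A trained on 100 examples\n",
    "    (300, [3.0, 6.0]),  # client B trained on 300 examples\n",
    "]\n",
    "print(fedavg(updates))  # prints: [2.5, 5.0]\n",
    "```\n",
    "\n",
    "Client B holds three times as much data as client A, so the aggregated parameters land three quarters of the way towards client B's values."
   ]
  },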
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_evaluate_fn(centralized_testset: Dataset):\n",
    "    \"\"\"This is a function that returns a function. The returned\n",
    "    function (i.e. `evaluate_fn`) will be executed by the strategy\n",
    "    at the end of each round to evaluate the state of the global\n",
    "    model.\"\"\"\n",
    "\n",
    "    def evaluate_fn(server_round: int, parameters, config):\n",
    "        \"\"\"This function is executed by the strategy. It will instantiate\n",
    "        a model and replace its parameters with those from the global model.\n",
    "        Then, the model will be evaluated on the test set (recall this is the\n",
    "        whole MNIST test set).\"\"\"\n",
    "\n",
    "        model = Net(num_classes=10)\n",
    "\n",
    "        # Determine device\n",
    "        device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "        model.to(device)  # send model to device\n",
    "\n",
    "        # set parameters to the model\n",
    "        params_dict = zip(model.state_dict().keys(), parameters)\n",
    "        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n",
    "        model.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "        # Apply transform to dataset\n",
    "        testset = centralized_testset.with_transform(apply_transforms)\n",
    "\n",
    "        testloader = DataLoader(testset, batch_size=50)\n",
    "        # call test\n",
    "        loss, accuracy = test(model, testloader, device)\n",
    "        return loss, {\"accuracy\": accuracy}\n",
    "\n",
    "    return evaluate_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We could now define a strategy using only the `evaluate_fn` above. Instead, let's see how additional (but entirely optional) functionality can be easily added to our strategy. We are going to define two auxiliary functions: (1) one to configure how clients do local training; and (2) one to aggregate the metrics that clients return after running their `evaluate` methods:\n",
    "\n",
    "1. `fit_config()`: This is a function that will be executed inside the strategy when configuring a new `fit` round. This function is relatively simple and only requires as input argument the current round of the FL experiment. In this example we simply return a Python dictionary to specify the number of epochs and the learning rate each client should make use of inside their `fit()` methods. A more versatile implementation would add more hyperparameters (e.g. the momentum) and adjust them as the FL process advances (e.g. reducing the learning rate in later FL rounds).\n",
    "2. `weighted_average()`: This is an optional function to pass to the strategy. It will be executed after an evaluation round (i.e. when clients run `evaluate()`) and will aggregate the metrics clients return. In this example, we use this function to compute the weighted average accuracy of clients doing `evaluate()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr.common import Metrics\n",
    "\n",
    "\n",
    "def fit_config(server_round: int) -> Dict[str, Scalar]:\n",
    "    \"\"\"Return a configuration with static learning rate and (local) epochs.\"\"\"\n",
    "    config = {\n",
    "        \"epochs\": 1,  # Number of local epochs done by clients\n",
    "        \"lr\": 0.01,  # Learning rate to use by clients during fit()\n",
    "    }\n",
    "    return config\n",
    "\n",
    "\n",
    "def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:\n",
    "    \"\"\"Aggregation function for (federated) evaluation metrics, i.e. those returned by\n",
    "    the client's evaluate() method.\"\"\"\n",
    "    # Multiply accuracy of each client by number of examples used\n",
    "    accuracies = [num_examples * m[\"accuracy\"] for num_examples, m in metrics]\n",
    "    examples = [num_examples for num_examples, _ in metrics]\n",
    "\n",
    "    # Aggregate and return custom metric (weighted average)\n",
    "    return {\"accuracy\": sum(accuracies) / sum(examples)}"
   ]
  },
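  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The static `fit_config()` above could be made round-dependent. The sketch below is one hypothetical variant: the name `fit_config_with_decay()` and the decay schedule are assumptions made for illustration, not part of Flower. It halves the learning rate every five rounds while keeping the number of local epochs fixed.\n",
    "\n",
    "```python\n",
    "def fit_config_with_decay(server_round, base_lr=0.01, decay=0.5, step=5):\n",
    "    # Rounds start at 1, so rounds 1-5 use base_lr, rounds 6-10 use\n",
    "    # base_lr * decay, rounds 11-15 use base_lr * decay**2, and so on\n",
    "    lr = base_lr * decay ** ((server_round - 1) // step)\n",
    "    return {'epochs': 1, 'lr': lr}\n",
    "\n",
    "\n",
    "for server_round in (1, 5, 6, 11):\n",
    "    print(server_round, fit_config_with_decay(server_round))\n",
    "```\n",
    "\n",
    "Because the extra arguments have defaults, such a function could be passed to the strategy through `on_fit_config_fn` in the same way as `fit_config`."
   ]
  },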
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can define our strategy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "strategy = fl.server.strategy.FedAvg(\n",
    "    fraction_fit=0.1,  # Sample 10% of available clients for training\n",
    "    fraction_evaluate=0.05,  # Sample 5% of available clients for evaluation\n",
    "    on_fit_config_fn=fit_config,\n",
    "    evaluate_metrics_aggregation_fn=weighted_average,  # aggregates federated metrics\n",
    "    evaluate_fn=get_evaluate_fn(centralized_testset),  # global evaluation function\n",
    ")"
   ]
  },
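  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what `fraction_fit` and `fraction_evaluate` mean in practice, the sketch below estimates how many clients are sampled per round. It assumes the strategy takes the integer part of the fraction times the number of available clients and enforces a minimum of two clients, which is how `FedAvg` appears to behave with its default settings; check the strategy source for the exact rule. `num_sampled()` is a hypothetical helper, not a Flower function.\n",
    "\n",
    "```python\n",
    "def num_sampled(num_available, fraction, min_clients=2):\n",
    "    # Sample a fraction of the available clients, but never fewer than\n",
    "    # min_clients (assumed here to default to 2)\n",
    "    return max(int(num_available * fraction), min_clients)\n",
    "\n",
    "\n",
    "print(num_sampled(100, 0.1))   # clients per fit round, prints: 10\n",
    "print(num_sampled(100, 0.05))  # clients per evaluate round, prints: 5\n",
    "```\n",
    "\n",
    "Under these assumptions, each round trains on 10 of the 100 clients and runs federated evaluation on 5 of them."
   ]
  },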
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far we have:\n",
    "* created the dataset partitions (one for each client)\n",
    "* defined the client class\n",
    "* decided on a strategy to use\n",
    "\n",
    "Now we just need to launch the Flower FL experiment... not so fast! Just one final function: let's create another callback that the Simulation Engine will use in order to spawn VirtualClients. As you can see, this is really simple: construct a `FlowerClient` object for each client and assign it its own data partition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "def get_client_fn(dataset: FederatedDataset):\n",
    "    \"\"\"Return a function to construct a client.\n",
    "\n",
    "    The VirtualClientEngine will execute this function whenever a client is sampled by\n",
    "    the strategy to participate.\n",
    "    \"\"\"\n",
    "\n",
    "    def client_fn(cid: str) -> fl.client.Client:\n",
    "        \"\"\"Construct a FlowerClient with its own dataset partition.\"\"\"\n",
    "\n",
    "        # Let's get the partition corresponding to the i-th client\n",
    "        client_dataset = dataset.load_partition(int(cid), \"train\")\n",
    "\n",
    "        # Now let's split it into train (90%) and validation (10%)\n",
    "        client_dataset_splits = client_dataset.train_test_split(test_size=0.1, seed=42)\n",
    "\n",
    "        trainset = client_dataset_splits[\"train\"]\n",
    "        valset = client_dataset_splits[\"test\"]\n",
    "\n",
    "        # Now we apply the transform to each batch.\n",
    "        trainloader = DataLoader(\n",
    "            trainset.with_transform(apply_transforms), batch_size=32, shuffle=True\n",
    "        )\n",
    "        valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n",
    "\n",
    "        # Create and return client\n",
    "        return FlowerClient(trainloader, valloader).to_client()\n",
    "\n",
    "    return client_fn\n",
    "\n",
    "\n",
    "client_fn_callback = get_client_fn(mnist_fds)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to launch the FL experiment using Flower simulation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "9ad8dcea-8004-4c6e-a025-e168da636c88"
   },
   "outputs": [],
   "source": [
    "# With a dictionary, you tell Flower's VirtualClientEngine that each\n",
    "# client needs exclusive access to these many resources in order to run\n",
    "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n",
    "\n",
    "# Let's disable tqdm progress bar in the main thread (used by the server)\n",
    "disable_progress_bar()\n",
    "\n",
    "history = fl.simulation.start_simulation(\n",
    "    client_fn=client_fn_callback,  # a callback to construct a client\n",
    "    num_clients=NUM_CLIENTS,  # total number of clients in the experiment\n",
    "    config=fl.server.ServerConfig(num_rounds=10),  # let's run for 10 rounds\n",
    "    strategy=strategy,  # the strategy that will orchestrate the whole FL pipeline\n",
    "    client_resources=client_resources,\n",
    "    actor_kwargs={\n",
    "        \"on_actor_init_fn\": disable_progress_bar  # disable tqdm on each actor/process spawning virtual clients\n",
    "    },\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Doing 10 rounds should take less than 2 minutes on a CPU-only Colab instance <-- Flower Simulation is fast! 🚀\n",
    "\n",
    "You can then use the returned `History` object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how to plot the centralised accuracy obtained at the end of each round (including at the very beginning of the experiment) for the _global model_. This is what the function `evaluate_fn()` that we passed to the strategy reports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 508
    },
    "outputId": "d8eab106-cee9-4266-9082-0944882cdba8"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print(f\"{history.metrics_centralized = }\")\n",
    "\n",
    "global_accuracy_centralised = history.metrics_centralized[\"accuracy\"]\n",
    "rounds = [data[0] for data in global_accuracy_centralised]\n",
    "acc = [100.0 * data[1] for data in global_accuracy_centralised]\n",
    "plt.plot(rounds, acc)\n",
    "plt.grid()\n",
    "plt.ylabel(\"Accuracy (%)\")\n",
    "plt.xlabel(\"Round\")\n",
    "plt.title(\"MNIST - IID - 100 clients with 10 clients per round\")"
   ]
  },
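  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Saving the results to disk can be as simple as dumping the metric lists to JSON. The sketch below uses hypothetical placeholder values in place of `history.metrics_centralized['accuracy']`; the file name `fl_results.json` and the temporary directory are likewise assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical (round, accuracy) pairs standing in for\n",
    "# history.metrics_centralized['accuracy']; these are not real results\n",
    "centralised_accuracy = [(0, 0.10), (1, 0.55), (2, 0.80)]\n",
    "\n",
    "path = os.path.join(tempfile.gettempdir(), 'fl_results.json')\n",
    "with open(path, 'w') as f:\n",
    "    json.dump({'accuracy': centralised_accuracy}, f)\n",
    "\n",
    "with open(path) as f:\n",
    "    restored = json.load(f)\n",
    "\n",
    "# JSON has no tuple type, so each pair comes back as a two-element list\n",
    "print(restored['accuracy'])\n",
    "```\n",
    "\n",
    "Storing plain lists keeps the file readable by any tool, at the cost of losing the tuple type on reload."
   ]
  },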
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congratulations! With that, you built a Flower client, customized its instantiation through the `client_fn`, customized the server-side execution through a `FedAvg` strategy configured for this workload, and started a simulation with 100 clients (each holding their own individual partition of the MNIST dataset).\n",
    "\n",
    "Next, you can continue to explore more advanced Flower topics:\n",
    "\n",
    "- Deploy server and clients on different machines using `start_server` and `start_client`\n",
    "- Customize the server-side execution through custom strategies\n",
    "- Customize the client-side execution through `config` dictionaries\n",
    "\n",
    "Get all resources you need!\n",
    "\n",
    "* **[DOCS]** Our complete documentation: https://flower.ai/docs/\n",
    "* **[Examples]** All Flower examples: https://flower.ai/docs/examples/\n",
    "* **[VIDEO]** Our YouTube channel: https://www.youtube.com/@flowerlabs\n",
    "\n",
    "Don't forget to join our Slack channel: https://flower.ai/join-slack/\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
